package afeng.world.pangu.tools;

import afeng.world.pangu.common.PanguConstants;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;

/**
 * 1.计算损失
 * 2.固化一定范围最小损失 假设当前最小,到label宽度的2/3范围没有最小就固化, 调整损失点和损失值
 * 3.清洗固化的损失
 */
@Slf4j
public class ImageDetection {

    private int cpuNum = Runtime.getRuntime().availableProcessors();
    private ExecutorService executorService = new ThreadPoolExecutor(cpuNum, cpuNum, 60L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());

    //损失
    private List<DetectTarget> detectTargetList = new ArrayList<>();
    private double currentMinLoss = -1;
    private Point currentLossPoint = new Point();
    private volatile String currentType = null;

    private final String detectImagePath;
    private final String labelImagePath_shang;
    private final String labelImagePath_xia;
    private final String labelImagePath_zuo;
    private final String labelImagePath_you;
    private final String labelImagePath_k;
    private final String labelImagePath_j;
    private final BufferedImage label_shang;
    private final BufferedImage label_xia;
    private final BufferedImage label_zuo;
    private final BufferedImage label_you;
    private final BufferedImage label_j;
    private final BufferedImage label_k;
    private final BufferedImage detectImage;
    private final int detectWith;
    private final int detectHeight;
    private int labelMinWith;
    private int labelMinHeight;
    private int labelMaxWith;
    private int labelMaxHeight;
    public ImageDetection(String detectImagePath, String labelImagePath_shang, String labelImagePath_xia, String labelImagePath_zuo, String labelImagePath_you, String labelImagePath_k, String labelImagePath_j) {
        this.detectImagePath = detectImagePath;
        this.labelImagePath_shang = labelImagePath_shang;
        this.labelImagePath_xia = labelImagePath_xia;
        this.labelImagePath_zuo = labelImagePath_zuo;
        this.labelImagePath_you = labelImagePath_you;
        this.labelImagePath_k = labelImagePath_k;
        this.labelImagePath_j = labelImagePath_j;
        this.detectImage = PictureTools.getBufferedImage(detectImagePath);
        this.label_shang = PictureTools.getBufferedImage(labelImagePath_shang);
        this.label_xia = PictureTools.getBufferedImage(labelImagePath_xia);
        this.label_zuo = PictureTools.getBufferedImage(labelImagePath_zuo);
        this.label_you = PictureTools.getBufferedImage(labelImagePath_you);
        this.label_j = PictureTools.getBufferedImage(labelImagePath_j);
        this.label_k = PictureTools.getBufferedImage(labelImagePath_k);
        this.detectWith = detectImage.getWidth();
        this.detectHeight = detectImage.getHeight();
        setMaxMinSize();
        loadFutureData();
    }

    private void loadFutureData() {
        int width = label_shang.getWidth();
        int height = label_shang.getHeight();
        // 12六个样本,每组2个, with宽, height高, 3个像素值
        int[][][][] data = new int[12][width][height][3];

        for (int i = 0; i < width; i++) {
            for (int j = 0; j < height; j++) {
                int shangRGB = label_shang.getRGB(i, j);
                int lr = (shangRGB >> 16 & 0xff);
                int lg = (shangRGB >> 8 & 0xff);
                int lb = (shangRGB & 0xff);
                data[1][i][j][0] = lr;
                data[1][i][j][1] = lg;
                data[1][i][j][2] = lb;
            }
        }
        //数组序列化反序列 todo
    }

    /**
     * 计算标本图片最值
     */
    private void setMaxMinSize() {
        this.labelMinWith = label_shang.getWidth();
        if (labelMinWith > label_xia.getWidth()) {
            labelMinWith = label_xia.getWidth();
        }
        if (labelMinWith > label_zuo.getWidth()) {
            labelMinWith = label_zuo.getWidth();
        }
        if (labelMinWith > label_you.getWidth()) {
            labelMinWith = label_you.getWidth();
        }
        if (labelMinWith > label_k.getWidth()) {
            labelMinWith = label_k.getWidth();
        }
        if (labelMinWith > label_j.getWidth()) {
            labelMinWith = label_j.getWidth();
        }
        this.labelMaxWith = label_shang.getWidth();
        if (labelMaxWith < label_xia.getWidth()) {
            labelMaxWith = label_xia.getWidth();
        }
        if (labelMaxWith < label_zuo.getWidth()) {
            labelMaxWith = label_zuo.getWidth();
        }
        if (labelMaxWith < label_you.getWidth()) {
            labelMaxWith = label_you.getWidth();
        }
        if (labelMaxWith < label_k.getWidth()) {
            labelMaxWith = label_k.getWidth();
        }
        if (labelMaxWith < label_j.getWidth()) {
            labelMaxWith = label_j.getWidth();
        }

        this.labelMinHeight = label_shang.getHeight();
        if (labelMinHeight > label_xia.getHeight()) {
            labelMinHeight = label_xia.getHeight();
        }
        if (labelMinHeight > label_zuo.getHeight()) {
            labelMinHeight = label_zuo.getHeight();
        }
        if (labelMinHeight > label_you.getHeight()) {
            labelMinHeight = label_you.getHeight();
        }
        if (labelMinHeight > label_k.getHeight()) {
            labelMinHeight = label_k.getHeight();
        }
        if (labelMinHeight > label_j.getHeight()) {
            labelMinHeight = label_j.getHeight();
        }
        this.labelMaxHeight = label_shang.getHeight();
        if (labelMaxHeight < label_xia.getHeight()) {
            labelMaxHeight = label_xia.getHeight();
        }
        if (labelMaxHeight < label_zuo.getHeight()) {
            labelMaxHeight = label_zuo.getHeight();
        }
        if (labelMaxHeight < label_you.getHeight()) {
            labelMaxHeight = label_you.getHeight();
        }
        if (labelMaxHeight < label_k.getHeight()) {
            labelMaxHeight = label_k.getHeight();
        }
        if (labelMaxHeight < label_j.getHeight()) {
            labelMaxHeight = label_j.getHeight();
        }
    }

    public void handle() {
        long startTime = System.currentTimeMillis();
        /**
         * 扫描大全图
         * 21:18:34.866 [main] INFO afeng.world.pangu.tools.PictureTools - 耗时-->247457  [x=1040,y=980]
         * 21:18:34.866 [main] INFO afeng.world.pangu.tools.PictureTools - 目标坐标-->java.awt.Point[x=1040,y=976]
         */
        Future<?> submit = null;
        for (int x = 0; x < detectWith - labelMaxWith; ) {
            int finalX = x;
            submit = executorService.submit(() -> {
                log.info("提交线程扫描行-->" + finalX );
                for (int y = 0; y < detectHeight - labelMaxHeight; ) {
                    verifyImage(finalX, y);
                    y += PanguConstants.Y_pixel;
                }
            });
            x += PanguConstants.X_pixel;  // 二十倍数,提升速度
        }
        try {
            submit.get();
            executorService.shutdownNow();
            Graphics graphics = detectImage.getGraphics();
            graphics.setColor(Color.RED);
            List<DetectTarget> list = new ArrayList<>();
            int count = 0;
            DetectTarget detectTargetOld = null;
            for (DetectTarget detectTarget : detectTargetList) {
                //清洗
                if (count == 0) {
                    count++;
                    detectTargetOld = detectTarget;
                    continue;
                }
                //方案一
                int distanceX = detectTarget.getTargetPoint().x - detectTargetOld.getTargetPoint().x;
                int distanceY = detectTarget.getTargetPoint().y - detectTargetOld.getTargetPoint().y;
                if (distanceX > labelMinWith * 1/2 && distanceX < labelMaxWith * 3 / 2
//                        && distanceY > -(labelMinHeight / 3) && distanceY < labelMinHeight / 3
                ) {
                    list.add(detectTarget);
//                    double lossCmp = detectTarget.getLoss() - detectTargetOld.getLoss();
//                    System.out.println(lossCmp);
                }
                //方案二 损失值比较
                detectTargetOld = detectTarget;
            }
            for (DetectTarget detectTarget : list) {
                graphics.drawRect(detectTarget.getTargetPoint().x, detectTarget.getTargetPoint().y, labelMinWith, labelMinHeight);
                log.info("目标坐标-->" + detectTarget.getTargetPoint()+"--->图"+detectTarget.getType());
            }
            ImageIO.write(detectImage, "png", new File("target.png"));
            log.info("耗时-->" + (System.currentTimeMillis() - startTime));
        } catch (IOException | InterruptedException | ExecutionException e) {
            log.error(e.getMessage());
        }
    }

    /**
     * 扫描框定图比对
     */
    private void verifyImage(int x, int y) {
        if (detectWith < x + labelMaxWith || detectHeight < y + labelMaxHeight) {
            return;
        }
        /**
         * 开始对比图片相似度
         */
        double loss_shang = 0, loss_xia = 0, loss_zuo = 0, loss_you = 0, loss_k = 0, loss_j = 0;
        for (int i = x; i < x + labelMinWith; i++) {
            for (int j = y; j < y + labelMinHeight; j++) {

                int d = detectImage.getRGB(i, j);
                int dr = (d >> 16 & 0xff);
                int dg = (d >> 8 & 0xff);
                int db = (d & 0xff);
                //把特征做成数据 todo
                loss_shang += claculateLoss(i, j, x, y, label_shang, dr, dg, db);
                loss_xia += claculateLoss(i, j, x, y, label_xia, dr, dg, db);
                loss_zuo += claculateLoss(i, j, x, y, label_zuo, dr, dg, db);
                loss_you += claculateLoss(i, j, x, y, label_you, dr, dg, db);
                loss_k += claculateLoss(i, j, x, y, label_k, dr, dg, db);
                loss_j += claculateLoss(i, j, x, y, label_j, dr, dg, db);
            }
        }
        //(x,y)点的损失
        double loss = loss_shang;
        String type = "shang";
        if (loss > loss_xia) { //拿最小
            loss = loss_xia;
            type = "xia";
        }
        if (loss > loss_zuo) {
            loss = loss_zuo;
            type = "zuo";
        }
        if (loss > loss_you) {
            loss = loss_you;
            type = "you";
        }
        if (loss > loss_k) {
            loss = loss_k;
            type = "k";
        }
        if (loss > loss_j) {
            loss = loss_j;
            type = "j";
        }

        synchronized (this) {
            if (loss < currentMinLoss) {
                currentMinLoss = loss;
                currentType = type;
                currentLossPoint = new Point(x, y);
            } else if (currentMinLoss == -1) {
                //初始化
                currentMinLoss = loss;
                currentType = type;
                currentLossPoint = new Point(x, y);
            }

            if (x - currentLossPoint.x > labelMinWith * 2 / 3) { //多线程下几率出问题
                //固化
                DetectTarget detectTarget = new DetectTarget(currentLossPoint, labelMinWith, labelMinHeight, currentMinLoss, currentType);
                detectTargetList.add(detectTarget);
                //重置
                currentMinLoss = loss;
                currentType = type;
                currentLossPoint = new Point(x, y);
            }
        }
    }

    /**
     * 计算损失
     *
     * @param xx    yy 移动的小矩阵坐标
     * @param x
     * @param y
     * @param label 目标图片
     * @param
     * @return
     */
    public double claculateLoss(int xx, int yy, int x, int y, BufferedImage label, int dr, int dg, int db) {
        int kRGB = label.getRGB(xx - x, yy - y);
        int lr = (kRGB >> 16 & 0xff);
        int lg = (kRGB >> 8 & 0xff);
        int lb = (kRGB & 0xff);

        return Math.pow(dr - lr, 2) + Math.pow(dg - lg, 2) + Math.pow(db - lb, 2);
    }

    /**
     * 测试
     *
     * @param args
     */
    public static void main(String[] args) {
        String detectImagePath = "asset/img/detect/detectImage5.png";
        String shang = "asset/img/detect/shang.png";
        String xia = "asset/img/detect/xia.png";
        String zuo = "asset/img/detect/zuo.png";
        String you = "asset/img/detect/you.png";
        String k = "asset/img/detect/k.png";
        String j = "asset/img/detect/j.png";
        ImageDetection imageDetection = new ImageDetection(detectImagePath, shang, xia, zuo, you, k, j);
        imageDetection.handle();
    }

    /**
     * 检测目标信息
     */
    @Data
    @AllArgsConstructor
    class DetectTarget {
        //损失点, 一个损失点代表一个框, 同目标比较相似度
        Point targetPoint;
        int with;
        int height;
        //损失值
        double loss;
        String type;
    }
}